/**
 * @file   MultigridSolver.h
 * @author ych <ych@Ubuntu18-04>
 * @date   Sun Jun 27 12:20:25 2021
 * 
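 * @brief  Geometric multigrid (V-cycle) solver for sparse linear systems, plus a P1 variant (P1_MGSolver) built on regular 2D meshes.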
 * 
 * 
 */
#ifndef MultigridSolver_h_
#define MultigridSolver_h_
#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

#include <Eigen/Sparse>

#include "Mesh.h"
// Power-of-two helpers implemented with bit shifts:
//   Multi2(a, bit)  == a * 2^bit
//   Devide2(a, bit) == a / 2^bit (integer division, rounds toward zero for a >= 0)
//   POW2(a)         == 2^a
// Every argument is parenthesised so that expressions containing operators of
// lower precedence than << / >> (e.g. |, &, ?:) expand as intended.
#define Multi2(a,bit) ((a) << (bit))
#define Devide2(a,bit) ((a) >> (bit))
#define POW2(a) (1 << (a))
// Template header shared by the MultigridSolver member definitions below.
#define TEMPLATE template<int DIM, typename MatrixType>
using VectorXd = Eigen::VectorXd;  // dense double vector used throughout this header
/**
 * @brief Smoother selects the pre/post-smoothing iteration of the multigrid solver.
 *        WJ  = 1: weighted Jacobi iteration;
 *        GS  = 2: Gauss-Seidel iteration;
 *        SOR = 3: successive over-relaxation.
 *        These three are currently the only available choices.
 */
enum Smoother
{
    WJ  = 1,
    GS  = 2,
    SOR = 3
};
TEMPLATE
class MultigridSolver
{
public:
    typedef MatrixType SpMat;
    typedef Eigen::Triplet<double> Tri;
    typedef std::function<double(const Point<2, double>&)> Function;

    /**
     * @brief Virtual destructor.
     *
     * MultigridSolver is an abstract base (Initialize and GenerateP are pure
     * virtual) and derived solvers release resources in their destructors, so
     * deleting a solver through a base-class pointer must dispatch to them.
     */
    virtual ~MultigridSolver() = default;

protected:
    /**
     * @brief Coefficient matrix of the linear system to be solved.
     */
    SpMat _A_origin;
    /**
     * @brief Coefficient matrices of every multigrid level (index 0 = finest).
     */
    std::vector<SpMat> _A;
    /**
     * @brief Right-hand side of the linear system to be solved.
     */
    VectorXd _Rhs_origin;
    /**
     * @brief Temporary right-hand side used for assignment/storage during the computation.
     */
    VectorXd _Rhs_temp;
    /**
     * @brief Right-hand sides of the systems on each level (index 0 = finest).
     */
    std::vector<VectorXd> _Rhs;
    /**
     * @brief Per-level unknowns: the solution on the finest level, the
     *        correction (residual-equation solution) on coarser levels.
     */
    std::vector<VectorXd> _u;
    /**
     * @brief Boundary-condition function g(x) evaluated at boundary dofs.
     */
    Function _BCF;
    /**
     * @brief Level the solver is currently working on (1-based, 1 = finest).
     */
    int _nowlevel = 1;
    /**
     * @brief Meshes of every level (index 0 = finest). The base class never
     *        deletes them; ownership is left to the derived solver.
     */
    std::vector<Mesh<DIM, double>*> _mesh;
    /**
     * @brief Solver used on the coarsest level; currently the (iterative)
     *        conjugate-gradient method, not a direct factorisation.
     */
    Eigen::ConjugateGradient<MatrixType > Solver_sparse;
    /**
     * @brief Stopping criterion: iterate while ||Au - f||_2 >= this value.
     */
    double _Tolerance = 1e-6;
    /**
     * @brief Stopping criterion: maximum number of outer iterations.
     */
    long _MaxIternum =  20;
    /**
     * @brief Prolongation (interpolation) operators; _P[l] maps level l+2 to level l+1.
     *        Filled by Initialize.
     */
    std::vector<SpMat> _P;
    /**
     * @brief Restriction operators; _P_t[l] maps level l+1 to level l+2.
     *        Filled by Initialize.
     */
    std::vector<SpMat> _P_t;
    /**
     * @brief Number of levels in the hierarchy; level _maxlevel is the coarsest.
     */
    int _maxlevel = 1;
    /**
     * @brief Smoothing method used in Smooth(); defaults to SOR.
     */
    Smoother smoothetype = SOR;
public:
    /**
     * @brief Solve the linear system with repeated V-cycles.
     *
     * @return VectorXd& reference to the finest-level solution _u[0]; later
     *         calls that grow _u (VCycle/Solve) may invalidate it.
     */
    VectorXd& Solve();
    /**
     * @brief Set the stopping tolerance.
     *
     * @param tolerance largest admissible l2 norm of the residual
     */
    void setTolerance(double tolerance){_Tolerance = tolerance;}
    /**
     * @brief Apply the boundary conditions of the current level to A and Rhs in place.
     *
     * @param A   coefficient matrix (modified: off-diagonal couplings of boundary dofs are zeroed)
     * @param Rhs right-hand side (modified only on the finest level)
     */
    void DealWithBoundaryCondition(SpMat &A, VectorXd &Rhs);
    /**
     * @brief Move one level towards the coarse end of the hierarchy.
     */
    void UpLevel();
    /**
     * @brief Move one level back towards the finest grid.
     */
    void DownLevel();
    /**
     * @brief Recursive V-cycle starting at level startlevel (1 = finest).
     */
    void VCycle(int startlevel);
    // Full multigrid (FMG) cycle; declared for future use, not implemented yet.
    //void FMG();
    /**
     * @brief Run one sweep of the smoother selected by smoothetype on _u[_nowlevel-1].
     *
     * @param A   boundary-treated coefficient matrix of the current level
     * @param Rhs right-hand side of the current level
     * @param s   level flag: when s == 1 (finest level) boundary dofs are set to
     *            Rhs[i] / A(i,i); otherwise they keep their current value
     */
    void Smooth(SpMat A, VectorXd Rhs,int s);
    /**
     * @brief Reset per-cycle state before the next iteration.
     */
    void Reset();
    /**
     * @brief Store a new coefficient matrix in _A_origin.
     *
     * NOTE: only _A_origin is replaced; the level matrices _A and the transfer
     * operators are not rebuilt.
     * @param A new coefficient matrix
     */
    void compute(SpMat A){_A_origin= A;}
    /**
     * @brief Build the level hierarchy (e.g. _mesh, _P and _P_t) from the finest matrix and mesh.
     *
     * @param A    finest-level coefficient matrix
     * @param mesh finest-level mesh
     */
    virtual void Initialize(SpMat A, Mesh<DIM, double>* mesh) = 0;
    /**
     * @brief Generate the interpolation and restriction operators for the current level.
     */
    virtual void GenerateP() = 0;
};

TEMPLATE
void MultigridSolver<DIM,MatrixType>::DealWithBoundaryCondition(SpMat &A, VectorXd &Rhs)
{
    Mesh<DIM, double>* mesh = _mesh[_nowlevel-1];
    // Only the finest level carries the real boundary data; coarser levels solve
    // for a correction whose boundary values are zero, so their RHS is untouched.
    const bool finest = (_nowlevel == 1);
    std::vector<int> neighbours;
    for (auto k : mesh->get_all_boundary())
    {
        const Point<2, double> bnd_point = mesh->get_dof(k);
        const double bnd_value = _BCF(bnd_point);
        if (finest)
            Rhs[k] = bnd_value * A.coeffRef(k, k);

        // Gather the off-diagonal indices of slice k before modifying A:
        // coeffRef may insert an entry when the sparsity pattern is not
        // symmetric, which would invalidate an iterator still walking A.
        // SpMat::InnerIterator / index() keep this valid for any storage order.
        neighbours.clear();
        for (typename SpMat::InnerIterator it(A, k); it; ++it)
        {
            if (it.index() != k)
                neighbours.push_back(static_cast<int>(it.index()));
        }
        for (int j : neighbours)
        {
            // Move the known value u_k = g(x_k) of row j to the right-hand side.
            if (finest)
                Rhs[j] -= A.coeff(j, k) * bnd_value;
            // Decouple the boundary dof from the rest of the system (keeps A symmetric).
            A.coeffRef(k, j) = 0.0;
            A.coeffRef(j, k) = 0.0;
        }
    }
}

TEMPLATE
inline void MultigridSolver<DIM,MatrixType>::UpLevel()
{
    // Step one level towards the coarse end of the hierarchy.
    ++_nowlevel;
}

TEMPLATE
inline void MultigridSolver<DIM,MatrixType>::DownLevel()
{
    // Step one level back towards the finest grid.
    --_nowlevel;
}

TEMPLATE
void MultigridSolver<DIM,MatrixType>::Smooth(SpMat A, VectorXd Rhs,int s)
{
    // The iterate of the current level; GS/SOR update it in place.
    VectorXd& u = _u[_nowlevel-1];
    Mesh<DIM, double>* mesh = _mesh[_nowlevel-1];
    const int n = static_cast<int>(u.size());

    // Rhs[i] minus the off-diagonal part of slice i applied to u as it currently
    // stands. For the default column-major storage slice i is column i, which
    // equals row i when A is symmetric.
    auto offDiagonalUpdate = [&](int i) -> double
    {
        double a = Rhs[i];
        for (typename SpMat::InnerIterator it(A, i); it; ++it)
        {
            if (it.index() != i)
                a -= (it.value() * u[it.index()]);
        }
        return a;
    };

    switch (smoothetype)
    {
    case WJ:
    {
        // Jacobi reads only the old iterate, so new values go into u_star and
        // are blended in afterwards. w = 1.0 is the undamped Jacobi iteration.
        const double w = 1.0;
        VectorXd u_star = VectorXd::Zero(Rhs.size());
        for (int i = 0; i < n; ++i)
        {
            const double a_ii = A.coeff(i, i);
            if (mesh->IsBoundary(i))
                u_star[i] = (s == 1) ? Rhs[i] / a_ii : u[i];
            else
                u_star[i] = offDiagonalUpdate(i) / a_ii;
        }
        u = (1 - w) * u + w * u_star;
        break;
    }
    case GS:
        for (int i = 0; i < n; ++i)
        {
            const double a_ii = A.coeff(i, i);
            if (!mesh->IsBoundary(i))
                u[i] = offDiagonalUpdate(i) / a_ii;
            else if (s == 1)
                u[i] = Rhs[i] / a_ii;
        }
        break;
    case SOR:
    {
        // Over-relaxation factor of the SOR sweep.
        const double w = 1.23;
        for (int i = 0; i < n; ++i)
        {
            const double a_ii = A.coeff(i, i);
            if (!mesh->IsBoundary(i))
            {
                const double a = offDiagonalUpdate(i);
                u[i] = w * a / a_ii + (1 - w) * u[i];
            }
            else if (s == 1)
            {
                u[i] = Rhs[i] / a_ii;
            }
        }
        break;
    }
    default:
        // Unknown smoother: leave the iterate unchanged.
        break;
    }
}

TEMPLATE
void MultigridSolver<DIM,MatrixType>::VCycle(int startlevel)
{
    _nowlevel = startlevel;
    const int level = startlevel;  // 1-based level handled by this call

    // Work on boundary-treated copies so the stored level operators stay intact.
    SpMat A_dealt = _A[level-1];
    VectorXd Rhs_dealt = _Rhs[level-1];
    DealWithBoundaryCondition(A_dealt, Rhs_dealt);

    if (level == _maxlevel)
    {
        // Coarsest level: solve the boundary-treated system directly with the
        // treated RHS (matters when the coarsest level is also the finest).
        Solver_sparse.compute(A_dealt);
        _u[level-1] = Solver_sparse.solve(Rhs_dealt);
        return;
    }

    // Pre-smoothing.
    for (int i = 0; i < 2; i++)
        Smooth(A_dealt, Rhs_dealt, level);

    // Restrict the residual; the coarse correction is zero on boundary dofs.
    VectorXd r_2h = _P_t[level-1] * (Rhs_dealt - A_dealt * _u[level-1]);
    for (int i = 0; i < r_2h.size(); i++)
        if (_mesh[level]->IsBoundary(i))
            r_2h[i] = 0.0;

    // Set up the next coarser level with a zero initial correction. Any stale
    // coarse entries (e.g. from a cycle not followed by Reset) are dropped first
    // so that index `level` really refers to the data pushed here.
    _u.resize(level);
    _Rhs.resize(level);
    _u.push_back(VectorXd::Zero(r_2h.size()));
    _Rhs.push_back(r_2h);

    UpLevel();
    VCycle(_nowlevel);
    DownLevel();

    // Coarse-grid correction.
    _u[level-1] = _u[level-1] + _P[level-1] * _u[level];

    // Post-smoothing.
    for (int i = 0; i < 3; i++)
        Smooth(A_dealt, Rhs_dealt, level);
}

TEMPLATE
void MultigridSolver<DIM,MatrixType>::Reset()
{
    // Keep only the finest-level entries (index 0): the next V-cycle starts from
    // the current fine-grid iterate and pushes fresh coarse-level data.
    _u.resize(1);
    _Rhs.resize(1);
    _Rhs_temp = _Rhs_origin;
}
TEMPLATE
VectorXd& MultigridSolver<DIM,MatrixType>::Solve()
{
    std::cout << "Solve..." << std::endl;

    // The boundary-treated finest-level system used for the residual check does
    // not change between cycles, so it is assembled once rather than per cycle.
    _nowlevel = 1;
    SpMat A_bc = _A_origin;
    VectorXd Rhs_bc = _Rhs_origin;
    DealWithBoundaryCondition(A_bc, Rhs_bc);

    // Run V-cycles until the l2 residual drops below _Tolerance or
    // _MaxIternum cycles have been performed (at least one cycle always runs).
    int times = 0;
    double residual = 0.0;
    do
    {
        VCycle(1);
        //FMG();
        Reset();
        const VectorXd residual_vector = A_bc * _u[0] - Rhs_bc;
        residual = residual_vector.norm();
        const double residual_max = residual_vector.lpNorm<Eigen::Infinity>();
        std::cout << " the "<< times + 1 << "th l2 norm of residual vecror is :" << residual << " and the max norm case is :" << residual_max<<std::endl;
    } while (++times < _MaxIternum && residual >= _Tolerance);
    return _u[0];
}
// The P1 solver below is fixed to DIM = 2 (it derives from MultigridSolver<2, ...>),
// so TEMPLATE is redefined to take only the matrix type.
#undef TEMPLATE
#define TEMPLATE template<typename MatrixType>

/***************************************************************/
/*************************** P1 MG *****************************/
/***************************************************************/

/**
 * @brief P1 multigrid solver, derived from MultigridSolver<2, MatrixType>.
 *
 * Initialize builds the solver's own RegularMesh hierarchy (one mesh per level)
 * from the mesh passed to the constructor, which is only read; the destructor
 * deletes the meshes the solver created.
 */
TEMPLATE
class P1_MGSolver:public MultigridSolver<2, MatrixType>
{
public:
    typedef MultigridSolver<2,MatrixType> _base;
    typedef typename _base::SpMat SpMat;
    typedef typename _base::Function Function;
    typedef Eigen::Triplet<double> Tri;
    using _base::_A_origin;
    using _base::_A;
    using _base::_Rhs_origin;
    using _base::_Rhs_temp;
    using _base::_Rhs;
    using _base::_u;
    using _base::_BCF;
    using _base::_nowlevel;
    using _base::_mesh;
    using _base::Solver_sparse;
    using _base::_P;
    using _base::_P_t;
    using _base::_maxlevel;
    using _base::_Tolerance;
    using _base::_MaxIternum;
    using _base::smoothetype;
public:
    P1_MGSolver(SpMat A, VectorXd Rhs,Function BCF, Mesh<2, double>* mesh, int maxlevel);
    ~P1_MGSolver()
    {
        // NOTE(review): the objects are RegularMesh instances deleted through
        // Mesh<2, double>*; this is only well-defined if Mesh declares a virtual
        // destructor -- confirm in Mesh.h.
        for (Mesh<2, double>* m : _mesh)
            delete m;
    }
    // The solver owns raw mesh pointers; a member-wise copy would delete them twice.
    P1_MGSolver(const P1_MGSolver&) = delete;
    P1_MGSolver& operator=(const P1_MGSolver&) = delete;
    void Initialize(SpMat A, Mesh<2, double>* mesh) override;
    void GenerateP() override;
};

TEMPLATE
P1_MGSolver<MatrixType>::P1_MGSolver(SpMat A, VectorXd Rhs,Function BCF, Mesh<2,double>* mesh, int maxlevel)
{
    // Initialize always builds _mesh[1] and _P[0], i.e. at least one coarse
    // level, and reads the given mesh. With maxlevel < 2 those slots would not
    // exist (and maxlevel <= 0 would make resize(maxlevel - 1) underflow).
    if (maxlevel < 2)
        throw std::invalid_argument("P1_MGSolver: maxlevel must be at least 2");
    if (mesh == nullptr)
        throw std::invalid_argument("P1_MGSolver: mesh must not be null");

    _nowlevel = 1;
    _maxlevel = maxlevel;
    _A_origin = A;
    _A.resize(maxlevel);
    _mesh.resize(maxlevel);       // value-initialised to nullptr
    _P.resize(maxlevel - 1);
    _P_t.resize(maxlevel - 1);
    Initialize(A, mesh);
    _Rhs_origin = Rhs;
    _Rhs_temp = Rhs;
    _Rhs.push_back(Rhs);
    // Start from a zero initial guess on the finest level.
    _u.push_back(VectorXd::Zero(Rhs.size()));
    _BCF = BCF;
    Solver_sparse.setTolerance(_Tolerance);
    smoothetype = GS;
}

TEMPLATE
void P1_MGSolver<MatrixType>::Initialize(SpMat A, Mesh<2,double>* mesh)
{
    std::cout << "Initialize..." << std::endl;

    // Read the finest-mesh geometry first: `mesh` may be one of the meshes this
    // solver owns (re-initialisation), and those are released just below.
    auto lbc = mesh->get_lbc();
    auto ruc = mesh->get_ruc();
    auto nx = mesh->get_nx();
    auto ny = mesh->get_ny();

    // Release meshes from a previous call so re-initialising does not leak them.
    for (Mesh<2, double>*& m : _mesh)
    {
        delete m;
        m = nullptr;
    }

    _nowlevel = 1;
    _A[0] = A;
    _mesh[0] = new RegularMesh(lbc, ruc, nx, ny);
    std::cout << "Generate P1 mesh..." << std::endl;
    // Each coarser mesh halves the number of cells in both directions.
    _mesh[1] = new RegularMesh(lbc, ruc, Devide2(nx, 1), Devide2(ny, 1));
    GenerateP();

    for (_nowlevel = 2; _nowlevel <= _maxlevel; ++_nowlevel)
    {
        // Galerkin coarse operator: A_l = P^T * A_{l-1} * P.
        _A[_nowlevel-1] = _P_t[_nowlevel-2] * _A[_nowlevel-2] * _P[_nowlevel-2];
        if (_nowlevel != _maxlevel)
        {
            Mesh<2, double>* fine = _mesh[_nowlevel-1];
            _mesh[_nowlevel] = new RegularMesh(fine->get_lbc(), fine->get_ruc(), Devide2(fine->get_nx(), 1), Devide2(fine->get_ny(), 1));
            GenerateP();
        }
    }
    _nowlevel = 1;
    std::cout << "Generate P Matrix and P_t Matrix..." <<std::endl;
    std::cout << "Done" << std::endl;
}

TEMPLATE
void P1_MGSolver<MatrixType>::GenerateP()
{
    // Build the P1 prolongation from _mesh[_nowlevel] (coarse) to
    // _mesh[_nowlevel-1] (fine) and its transpose as the restriction.
    Mesh<2, double>* fine = _mesh[_nowlevel-1];
    Mesh<2, double>* coarse = _mesh[_nowlevel];
    const int n_fine = fine->get_n_dofs();
    const int n_coarse = coarse->get_n_dofs();
    const int stride = fine->get_nx() + 1;   // fine-grid index offset of one row
    const int cnx = coarse->get_nx();
    const int cny = coarse->get_ny();

    // Each coarse node feeds itself plus at most six fine neighbours.
    std::vector<Tri> triplets;
    triplets.reserve(static_cast<std::size_t>(n_coarse) * 7);
    auto add = [&triplets](int fineIdx, int coarseIdx, double w)
    {
        triplets.emplace_back(fineIdx, coarseIdx, w);
    };

    for (int i = 0; i <= cny; ++i)
        for (int j = 0; j <= cnx; ++j)
        {
            const int c = coarse->CortoIdx(i, j);
            const int f = fine->CortoIdx(2*i, 2*j);
            add(f, c, 1.0);                       // coincident fine node

            if (i != 0)                           // not on the bottom edge
            {
                add(f - stride, c, 0.5);          // below
                if (j != 0)
                    add(f - stride - 1, c, 0.5);  // lower-left (diagonal edge)
                // lower-right is not connected in this triangulation
            }
            if (i != cny)                         // not on the top edge
            {
                add(f + stride, c, 0.5);          // above
                if (j != cnx)
                    add(f + stride + 1, c, 0.5);  // upper-right (diagonal edge)
                // upper-left is not connected in this triangulation
            }
            if (j != 0)
                add(f - 1, c, 0.5);               // left
            if (j != cnx)
                add(f + 1, c, 0.5);               // right
        }

    SpMat& P = _P[_nowlevel-1];
    P = SpMat(n_fine, n_coarse);
    P.setFromTriplets(triplets.begin(), triplets.end());
    P.makeCompressed();
    // The restriction is exactly the transpose of the prolongation.
    _P_t[_nowlevel-1] = SpMat(P.transpose());
    _P_t[_nowlevel-1].makeCompressed();
}


#undef TEMPLATE
#endif // MultigridSolver_h_
